Spectral GNNs leverage the spectral theory of graphs (the eigendecomposition of the graph Laplacian) to process graph-structured data. This tutorial explores graph Fourier transforms and spectral filtering on point-cloud graphs, then builds and trains spectral GNNs on the Cora citation network.
# Install required packages.
!pip install torch torchvision
!pip install networkx matplotlib scikit-learn
!pip install plotly
!pip install torch_geometric
!pip install -U kaleido
# Optional dependencies:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
import plotly.io as pio
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
# Helper function for visualization.
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import Image, display
from sklearn.manifold import TSNE
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from sklearn.neighbors import kneighbors_graph
import networkx as nx
import plotly.graph_objs as go
import plotly.io as pio
def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    plt.figure(figsize=(10, 10))
    plt.xticks([])
    plt.yticks([])
    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()
2.1.0+cu121
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
dataset = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())
print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0] # Get the first graph object.
print()
print(data)
print('===========================================================================================================')
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Dataset: Cora():
======================
Number of graphs: 1
Number of features: 1433
Number of classes: 7

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
===========================================================================================================
Number of nodes: 2708
Number of edges: 10556
Average node degree: 3.90
Number of training nodes: 140
Training node label rate: 0.05
Has isolated nodes: False
Has self-loops: False
Is undirected: True
def plot_interactive(domain, signal, title):
    precision = 4  # how many decimal places to keep in the color vector signal
    rounded_signal = np.round(signal, precision)
    trace = go.Scatter3d(
        x=domain[:, 0], y=domain[:, 1], z=domain[:, 2],
        mode='markers',
        marker=dict(
            size=3,
            color=rounded_signal,           # Color by signal values
            colorscale='Inferno',           # Color scale
            opacity=0.8,
            colorbar=dict(title='Value')    # Add a color bar
        )
    )
    layout = go.Layout(
        title=dict(text=title, y=0.9, x=0.5, xanchor='center', yanchor='top'),
        margin=dict(l=0, r=0, b=0, t=40),   # Adjust top margin to fit title
        scene=dict(
            xaxis=dict(title='X-axis'),
            yaxis=dict(title='Y-axis'),
            zaxis=dict(title='Z-axis')
        ),
        width=800,
        height=400
    )
    fig = go.Figure(data=[trace], layout=layout)
    static_image = pio.to_image(fig, format='png', width=800, height=400, scale=2)
    display(Image(static_image))
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
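`count_parameters` simply sums `numel()` over every parameter with `requires_grad=True`. A torch-free mock (the `Fake*` classes here are invented for the sketch) shows the bookkeeping, including that frozen parameters are excluded:

```python
# Minimal sketch of what count_parameters does, without torch.
class FakeParam:
    def __init__(self, n, trainable=True):
        self._n = n
        self.requires_grad = trainable
    def numel(self):
        return self._n

class FakeModel:
    def parameters(self):
        # weights/biases of a hypothetical 1433 -> 16 -> 7 model,
        # plus one frozen tensor that must not be counted
        return [FakeParam(1433 * 16), FakeParam(16),
                FakeParam(16 * 7), FakeParam(7),
                FakeParam(999, trainable=False)]

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(FakeModel()))  # 23063 (the frozen 999 is excluded)
```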
def plot3d_graph_plotly(points, G, title='3D Visualization'):
    edge_x = []
    edge_y = []
    edge_z = []
    for edge in G.edges():
        x0, y0, z0 = points[edge[0]]
        x1, y1, z1 = points[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])
    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        line=dict(width=2, color='blue'),
        mode='lines')
    node_trace = go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='markers',
        marker=dict(size=5, color='red', opacity=0.8)
    )
    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z'),
        width=800,
        height=400,
        title=title)
    static_image = pio.to_image(fig, format='png', width=800, height=400, scale=2)
    display(Image(static_image))
def create_torus(r=0.7, R=2, n_samples=5000):
    theta = np.random.uniform(low=0, high=2 * np.pi, size=n_samples)
    phi = np.random.uniform(low=0, high=2 * np.pi, size=n_samples)
    x = (R + r * np.cos(phi)) * np.cos(theta)
    y = (R + r * np.cos(phi)) * np.sin(theta)
    z = r * np.sin(phi)
    return np.column_stack((x, y, z))
points = create_torus()
plot_interactive(points,points[:,2],"Torus")
# Create a k-nearest neighbors graph from the points
A = kneighbors_graph(points, n_neighbors=20, mode='distance', include_self=True)
G = nx.from_scipy_sparse_array(A)
plot3d_graph_plotly(points, G,title = '3D Visualization of K-NN Graph on Torus')
laplacian = nx.laplacian_matrix(G).toarray()
print(laplacian.shape)
eigenvalues, eigenvectors = np.linalg.eigh(laplacian)
(5000, 5000)
plot_interactive(points,eigenvectors[:,0], "Interactive Manifold with 1st Laplacian Eigenvector")
plot_interactive(points,eigenvectors[:,1], "Interactive Manifold with 2nd Laplacian Eigenvector")
plot_interactive(points,eigenvectors[:,2], "Interactive Manifold with 3rd Laplacian Eigenvector")
plot_interactive(points,eigenvectors[:,6], "Interactive Manifold with 7th Laplacian Eigenvector") # play with colors - should look constant (showing numerical errors)
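The plots above rely on two basic spectral facts: for a connected graph the smallest Laplacian eigenvalue is 0, and its eigenvector is constant (the "DC component" of the graph Fourier basis). A minimal sketch on a 50-node ring graph (built locally here as a stand-in for the k-NN torus graph) confirms both:

```python
import numpy as np

# Combinatorial Laplacian of a ring graph (a discretized circle).
n = 50
A = np.zeros((n, n))
for i in range(n):
    A[i, (i + 1) % n] = A[(i + 1) % n, i] = 1
L_ring = np.diag(A.sum(axis=1)) - A

# eigh returns eigenvalues in ascending order with orthonormal eigenvectors.
evals, evecs = np.linalg.eigh(L_ring)

assert np.isclose(evals[0], 0.0)              # smallest eigenvalue is 0
assert np.allclose(evecs[:, 0], evecs[0, 0])  # its eigenvector is constant
# The remaining low eigenvalues come in near-degenerate sin/cos pairs,
# which is why individual plotted eigenvectors can look arbitrarily rotated.
```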
def create_mobius_strip(width=0.2, n_samples_theta=100, n_samples_phi=30):
    theta = np.linspace(0, 2 * np.pi, n_samples_theta)
    phi = np.linspace(-width, width, n_samples_phi)
    theta, phi = np.meshgrid(theta, phi)
    x = (1 + phi / 2 * np.cos(theta / 2)) * np.cos(theta)
    y = (1 + phi / 2 * np.cos(theta / 2)) * np.sin(theta)
    z = phi / 2 * np.sin(theta / 2)
    return np.column_stack((x.ravel(), y.ravel(), z.ravel()))
points = create_mobius_strip()
plot_interactive(points,points[:,2],"Möbius Strip")
# Create a k-nearest neighbors graph from the points
A = kneighbors_graph(points, n_neighbors=30, mode='distance', include_self=True)
G = nx.from_scipy_sparse_array(A)
plot3d_graph_plotly(points, G)
laplacian = nx.laplacian_matrix(G).toarray()
print(laplacian.shape)
(3000, 3000)
eigenvalues, eigenvectors = np.linalg.eigh(laplacian)
plot_interactive(points,eigenvectors[:,0], "Interactive Manifold with 1st Laplacian Eigenvector")
plot_interactive(points,eigenvectors[:,1], "Interactive Manifold with 2nd Laplacian Eigenvector")
plot_interactive(points,eigenvectors[:,2], "Interactive Manifold with 3rd Laplacian Eigenvector")
plot_interactive(points,eigenvectors[:,3], "Interactive Manifold with 4th Laplacian Eigenvector")
plot_interactive(points,eigenvectors[:,12], "Interactive Manifold with 13th Laplacian Eigenvector")
def sinc(x):
    # Suppress the 0/0 warning at x == 0; that entry is patched below.
    with np.errstate(invalid='ignore', divide='ignore'):
        f = np.sin(x) / x
    f[x == 0] = 1
    return f
def signal_func(points):
    sig = sinc(points[:, 0] * 10) * sinc(points[:, 2] * 60)
    # sig[points[:, 0] >= 0] = 0
    # sig[points[:, 1] <= 0] = 0
    return sig
## Build signal
graph_signal_clean = signal_func(points)
graph_signal = graph_signal_clean + np.random.normal(loc=0, scale=0.7, size=graph_signal_clean.shape)
plot_interactive(points, graph_signal, "Noisy Graph Signal")
plot_interactive(points, graph_signal_clean, "Clean Graph Signal")
Choose a filter $f(\lambda)$ in the frequency domain and apply it using the GFT:
$$\tilde{x} = x * g = \mathcal{F}^{-1}\{X^F \bullet G^F\}$$
# define filter frequency response
def f_bpf(l,lambda_min=0,lambda_max=1.2):
f = np.ones(l.shape)
f[l > lambda_max] = 0
f[l < lambda_min] = 0
return f
# plot frequency response
N = 1000
freq_vec = np.linspace(eigenvalues[0],eigenvalues[-1],N)
fil = f_bpf(freq_vec);
fig,ax = plt.subplots(1,1)
ax.plot(freq_vec,fil)
ax.set_xlabel(r"$\lambda$ - Graph Frequency")
ax.set_ylabel('Amplitude')
ax.set_title(r"$f(\lambda)$");
# Apply Filter
graph_signal_f = eigenvectors.T@graph_signal # GFT
signal_filtered_f = graph_signal_f*f_bpf(eigenvalues) # apply filter
signal_filtered = eigenvectors@signal_filtered_f # inverse GFT
# plot signal
plot_interactive(points,signal_filtered, f"Graph Signal - Naive Filtered")
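The three steps above (forward GFT, pointwise multiplication, inverse GFT) work because `eigh` of a symmetric Laplacian returns an orthonormal eigenvector matrix `U`, so the transform is exactly invertible. A self-contained sketch on a small random graph (all names are local to this example):

```python
import numpy as np

# Build a small random symmetric adjacency and its combinatorial Laplacian.
rng = np.random.default_rng(0)
A = (rng.random((30, 30)) < 0.2).astype(float)
A = np.triu(A, 1)
A = A + A.T                      # symmetric adjacency
L_small = np.diag(A.sum(axis=1)) - A
lam, U = np.linalg.eigh(L_small)

x = rng.standard_normal(30)
x_hat = U.T @ x                  # forward GFT
assert np.allclose(U @ x_hat, x) # inverse GFT recovers the signal exactly

# An ideal low-pass filter simply zeroes the high-frequency coefficients:
x_lp = U @ (x_hat * (lam <= lam[15]))
```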
Smooth the frequency response to encourage locality (smoother spectral multipliers correspond to better-localized filters in the vertex domain).
from scipy.interpolate import CubicSpline
# smoothen frequency response with cubic interpolation
hop = 20
f_pre_interp = f_bpf(eigenvalues[0::hop])
f_interp = CubicSpline(eigenvalues[0::hop],f_pre_interp, bc_type='natural')
freq_vec = np.linspace(eigenvalues[0],eigenvalues[-1],N)
fil = f_bpf(freq_vec);
fig,ax = plt.subplots(1,1)
ax.plot(freq_vec,fil, label = 'original')
ax.plot(freq_vec,f_interp(freq_vec),label = 'interpolated')
ax.set_xlabel(r"$\lambda$ - Graph Frequency")
ax.set_ylabel('Amplitude')
ax.set_title(r"$f(\lambda)$");
# Apply Filter
signal_filtered_f = graph_signal_f*f_interp(eigenvalues) # apply filter
signal_filtered = eigenvectors@signal_filtered_f # inverse GFT
# plot signal
plot_interactive(points,signal_filtered, f"Graph Signal - Smooth Spectral Multipliers Filtered")
# fit a polynomial filter to approximate the ideal brick-wall response
fig,ax = plt.subplots(1,1)
ax.plot(freq_vec,fil, label = 'Original')
ax.set_xlabel(r"$\lambda$ - Graph Frequency")
ax.set_ylabel('Amplitude')
ax.set_title(r"$f(\lambda)$");
p_order = np.arange(6, 22, 4)
for p in p_order:
    p_coeffs = np.polyfit(eigenvalues, f_bpf(eigenvalues), p)
    mymodel = np.poly1d(p_coeffs)
    ax.plot(freq_vec, mymodel(freq_vec), label=f'{p}-order polynomial')
ax.legend()
def graph_poly_filt(L, signal, p_coeffs):
    # initialize filtered signal
    signal_filt = np.zeros(signal.shape)
    Lx_prev = signal
    # polynomial coeff order runs from p[0]*x^deg down to p[deg]*x^0
    for p in np.flip(p_coeffs):
        signal_filt += p * Lx_prev
        # compute L^i @ x for the next iteration
        Lx_prev = L @ Lx_prev
    return signal_filt
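A quick way to validate `graph_poly_filt` is to check that it reproduces spectral filtering with multipliers $p(\lambda_i)$, since $p(L) = U\,p(\Lambda)\,U^T$. A sketch on a small random symmetric matrix (a stand-in "Laplacian" built locally; `poly_filt` mirrors the function above):

```python
import numpy as np

rng = np.random.default_rng(1)
M = rng.standard_normal((20, 20))
Lap = (M + M.T) / 2                      # small symmetric stand-in "Laplacian"
lam, U = np.linalg.eigh(Lap)
x = rng.standard_normal(20)
coeffs = np.array([0.5, -1.0, 2.0])      # p(t) = 0.5 t^2 - t + 2

def poly_filt(L, signal, p_coeffs):      # same logic as graph_poly_filt above
    out = np.zeros(signal.shape)
    Lx = signal
    for c in np.flip(p_coeffs):
        out += c * Lx
        Lx = L @ Lx
    return out

# Spectral route: transform, multiply by p(lambda_i), transform back.
spectral = U @ (np.polyval(coeffs, lam) * (U.T @ x))
assert np.allclose(poly_filt(Lap, x, coeffs), spectral)
```

The polynomial route never forms the eigendecomposition, which is why K-order polynomial filters scale to large graphs while also guaranteeing K-hop locality.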
# Apply Filter
p_order = np.append(np.arange(1, 5), 10)
for p in p_order:
    p_coeffs = np.polyfit(eigenvalues, f_bpf(eigenvalues), p)
    signal_filtered = graph_poly_filt(laplacian, graph_signal, p_coeffs)
    poly_model = np.poly1d(p_coeffs)
    # plot signal
    plot_interactive(points, signal_filtered, f"Graph Signal - {p}-order Polynomial filter")
Recall the Cora statistics: 1 graph, 1433 features per node (bag of words), 7 classes, 2708 nodes, 10556 edges (average node degree 3.90), 140 training nodes (label rate 0.05), no isolated nodes, no self-loops, undirected.
# Dimensionality reduction with t-SNE
tsne = TSNE(n_components=2, perplexity=30, n_iter=300)
x_tsne = tsne.fit_transform(data.x.detach().numpy())
# Get the edge index in COO format
edge_index = data.edge_index.numpy()
# Plot
plt.figure(figsize=(10, 8))
for i in range(dataset.num_classes):
    idx = data.y == i
    plt.scatter(x_tsne[idx, 0], x_tsne[idx, 1], label=f'Class {i}', s=20)  # s is the marker size
# Optionally, plot a subset of edges
if edge_index.shape[1] < 1000:  # Plot edges only for smaller graphs
    for i in range(edge_index.shape[1]):
        source = edge_index[0, i]
        target = edge_index[1, i]
        plt.plot(x_tsne[[source, target], 0], x_tsne[[source, target], 1], c='black', alpha=0.5)
plt.xlabel('TSNE Component 1')
plt.ylabel('TSNE Component 2')
plt.title('Cora Citation Network (t-SNE visualization)')
plt.legend()
plt.show()
We now define a spectral GNN layer that learns filter multipliers directly in the spectral domain, using a truncated eigenbasis of the graph Laplacian.
import torch
import torch.nn as nn
import torch.nn.functional as F
class SpectralGNNLayer(nn.Module):
    def __init__(self, eigenvectors, d_cutoff, in_channels, out_channels):
        super(SpectralGNNLayer, self).__init__()
        self.d_cutoff = d_cutoff
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.eigenvectors = torch.tensor(eigenvectors[:, :d_cutoff], dtype=torch.float32)  # Shape: [N, D]
        self.theta = nn.Parameter(torch.randn(d_cutoff, in_channels, out_channels))  # Shape: [D, F_in, F_out]
        nn.init.xavier_uniform_(self.theta)

    def forward(self, x):
        # Fourier transform
        x_spectral = self.eigenvectors.T @ x  # Shape: [D, N] @ [N, F_in] -> [D, F_in]
        # Introduce a singleton dimension to broadcast over output channels
        x_spectral = x_spectral.unsqueeze(-1)  # Shape: [D, F_in, 1]
        # Apply spectral filters
        x_spectral = torch.mul(x_spectral, self.theta)  # Shape: [D, F_in, F_out]
        # Sum over input channels
        filtered = torch.sum(x_spectral, dim=1)  # Shape: [D, F_out]
        # Inverse Fourier transform
        x_transformed = self.eigenvectors @ filtered  # Shape: [N, F_out]
        return x_transformed
class TwoLayerSpectralGNN(nn.Module):
    def __init__(self, eigenvectors, d_cutoff, num_features, num_hidden_features, num_output_classes):
        super(TwoLayerSpectralGNN, self).__init__()
        # Initialize both layers with the eigenvectors
        self.layer1 = SpectralGNNLayer(eigenvectors, d_cutoff, num_features, num_hidden_features)
        self.layer2 = SpectralGNNLayer(eigenvectors, d_cutoff, num_hidden_features, num_output_classes)

    def forward(self, x):
        # First layer with ReLU activation
        x = self.layer1(x)
        x = F.relu(x)
        # Second layer
        x = self.layer2(x)
        return x
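Before training, it helps to trace the shapes through `SpectralGNNLayer.forward`. A numpy-only sketch with small hypothetical sizes (N nodes, D kept eigenvectors, F_in input and F_out output channels; `U_D` stands in for the truncated eigenvector matrix):

```python
import numpy as np

N, D, F_in, F_out = 12, 5, 3, 2
rng = np.random.default_rng(2)
U_D = np.linalg.qr(rng.standard_normal((N, N)))[0][:, :D]  # orthonormal columns
X = rng.standard_normal((N, F_in))
theta = rng.standard_normal((D, F_in, F_out))              # learned multipliers

X_spec = U_D.T @ X                                   # GFT: [D, F_in]
filtered = (X_spec[:, :, None] * theta).sum(axis=1)  # filter + channel mix: [D, F_out]
out = U_D @ filtered                                 # inverse GFT: [N, F_out]
assert out.shape == (N, F_out)
```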
from torch_geometric.utils import to_scipy_sparse_matrix, add_self_loops
# Add self-loops to the adjacency matrix
edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.num_nodes)
# Compute the adjacency matrix as a sparse matrix
adjacency_matrix = to_scipy_sparse_matrix(edge_index)
# Compute the degree matrix
degrees = adjacency_matrix.sum(axis=0).A1 # Sum of each row
D = torch.diag(torch.pow(torch.tensor(degrees, dtype=torch.float32), -0.5))
# Convert the adjacency matrix to a torch tensor
A = torch.tensor(adjacency_matrix.todense(), dtype=torch.float32)
# Compute the normalized Laplacian
I = torch.eye(data.num_nodes)  # Identity matrix
L = I - torch.matmul(torch.matmul(D, A), D)
cora_eigenvalues, cora_eigenvectors = np.linalg.eigh(L.numpy())
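A useful property of this construction: the symmetric normalized Laplacian $I - D^{-1/2} A D^{-1/2}$ has all eigenvalues in $[0, 2]$. A quick numpy check on a random graph with self-loops, built locally for this sketch (mirroring the construction above):

```python
import numpy as np

rng = np.random.default_rng(3)
A = (rng.random((40, 40)) < 0.15).astype(float)
A = np.triu(A, 1)
A = A + A.T + np.eye(40)                 # self-loops guarantee nonzero degrees
d_inv_sqrt = 1.0 / np.sqrt(A.sum(axis=1))
L_sym = np.eye(40) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
lam = np.linalg.eigvalsh(L_sym)
assert lam.min() > -1e-8 and lam.max() < 2 + 1e-8
```

This bounded spectrum is what makes the Chebyshev rescaling used by polynomial filters (and the GCN simplification later on) well-defined.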
model = TwoLayerSpectralGNN(cora_eigenvectors, d_cutoff=100, num_features=data.x.shape[1], num_hidden_features=16, num_output_classes=7)
print("num of parameters is ", count_parameters(model))
print("d_cutoff * (num_features * num_hidden + num_hidden * num_classes) = 100 * (1433*16 + 16*7)")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss

def test():
    model.eval()
    out = model(data.x)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc
test_acc = test()
print(f'Accuracy before training: {test_acc:.4f}')
for epoch in range(1, 101):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
num of parameters is  2304000
Accuracy before training: 0.1690
Epoch: 001, Loss: 1.9459
Epoch: 002, Loss: 1.9437
...
Epoch: 099, Loss: 0.4340
Epoch: 100, Loss: 0.4319
test_acc = test()
print(f'Accuracy after training: {test_acc:.4f}')
model.eval()
out = model(data.x)
visualize(out, color=data.y)
Accuracy after training: 0.6850
Next we build a ChebNet, which parameterizes the filter as a Chebyshev polynomial of the Laplacian (so no explicit eigendecomposition is needed), and prepare for training.
from torch_geometric.nn import ChebConv
import torch
import torch.nn as nn
import torch.nn.functional as F
class ChebNet(nn.Module):
    def __init__(self, num_features, num_hidden_features, num_output_classes, K=2):
        super(ChebNet, self).__init__()
        # K defines the order of the Chebyshev polynomials
        self.conv1 = ChebConv(num_features, num_hidden_features, K=K)
        self.conv2 = ChebConv(num_hidden_features, num_output_classes, K=K)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x
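`ChebConv` expands the filter in Chebyshev polynomials $T_k$ of the rescaled Laplacian. The three-term recurrence $T_k(x) = 2x\,T_{k-1}(x) - T_{k-2}(x)$ is what makes evaluation cheap (one sparse mat-vec per order). A sketch checking the recurrence against numpy's reference Chebyshev basis on $[-1, 1]$:

```python
import numpy as np

def cheb(k, x):
    # Chebyshev polynomials of the first kind via the three-term recurrence.
    if k == 0:
        return np.ones_like(x)
    if k == 1:
        return x
    return 2 * x * cheb(k - 1, x) - cheb(k - 2, x)

xs = np.linspace(-1, 1, 101)
for k in range(5):
    assert np.allclose(cheb(k, xs), np.polynomial.chebyshev.Chebyshev.basis(k)(xs))
```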
model = ChebNet(num_features=data.num_features, num_hidden_features=16 ,num_output_classes=7, K=3)
print("num of parameters is ", count_parameters(model))
print("K * (num_features * num_hidden_features) + num_hidden_features (bias) + K * (num_hidden_features * num_output_classes) + num_output_classes (bias)")
print("3 * (1433 * 16) + 16 + 3 * (16 * 7) + 7")
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss

def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc
test_acc = test()
print(f'Accuracy before training: {test_acc:.4f}')
# Training loop (identical to before)
for epoch in range(1, 101):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
num of parameters is  69143
Accuracy before training: 0.1070
Epoch: 001, Loss: 1.9472
Epoch: 002, Loss: 1.7669
...
Epoch: 099, Loss: 0.1041
Epoch: 100, Loss: 0.0652
test_acc = test()
print(f'Accuracy after training: {test_acc:.4f}')
model.eval()
out = model(data)
visualize(out, color=data.y)
Accuracy after training: 0.8110
Finally, we try a GCN, which simplifies the Chebyshev filter to first order with shared weights.
from torch_geometric.nn import GCNConv
class GCN(nn.Module):
    def __init__(self, num_features, num_hidden_features, num_output_classes):
        super().__init__()
        # GCNConv is a first-order (K=1) simplification of ChebConv
        self.conv1 = GCNConv(num_features, num_hidden_features)
        self.conv2 = GCNConv(num_hidden_features, num_output_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x
model = GCN(num_features=data.num_features, num_hidden_features=16 ,num_output_classes=7)
print("num of parameters is ", count_parameters(model))
print("(num_features * num_hidden_features) + num_hidden_features (bias) + (num_hidden_features * num_output_classes) + num_output_classes (bias)")
print("(1433 * 16) + 16 + (16 * 7) + 7")
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss

def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc
test_acc = test()
print(f'Accuracy before training: {test_acc:.4f}')
# Training loop (identical to before)
for epoch in range(1, 101):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
num of parameters is  23063
Accuracy before training: 0.1740
Epoch: 001, Loss: 1.9461
Epoch: 002, Loss: 1.9285
...
Epoch: 099, Loss: 0.2474
Epoch: 100, Loss: 0.2449
test_acc = test()
print(f'Accuracy after training: {test_acc:.4f}')
model.eval()
out = model(data)
visualize(out, color=data.y)
Accuracy after training: 0.7770